```{r boilerplate}
# Attach the tidyverse (readr, dplyr, tidyr, stringr, purrr, ggplot2, ...)
# used by every chunk below.
library(tidyverse)
# Render all chunk figures with the cairo_pdf graphics device.
# NOTE(review): presumably chosen for font/symbol handling of the
# `expression()` legend titles — confirm.
knitr::opts_chunk$set(dev = "cairo_pdf")
```



```{r }

# Stack the 1-D asymmetric simulation results into one data frame. The file
# names suggest one file per sample size n (100, 500, 1000, 5000, 10000);
# rows are appended in the order the paths are listed.
res <- c(
  "results/asym_sim_1d_100n_14-Jun-2022-15-21-50.csv",
  "results/asym_sim_1d_500n_14-Jun-2022-15-24-23.csv",
  "results/asym_sim_1d_1000n_14-Jun-2022-15-26-48.csv",
  "results/asym_sim_1d_5000n_14-Jun-2022-15-30-53.csv",
  "results/asym_sim_1d_10000n_14-Jun-2022-15-36-55.csv"
) %>%
  map(read_csv) %>%
  bind_rows()



```

```{r sim_true_regret, fig.width = 4.5, fig.height = 4.5}
# Distinct upper bounds `ub`, in order of first appearance in `res` (not
# sorted). The regret plot drops the first and last of these; `ubs` is also
# reused by the misclassification chunk further down.
ubs <- unique(res$ub)
res %>%
  filter(ub %in% ubs[-c(1, length(ubs))]) %>%
  # Mean and SD of every regret column within each (ub, n) cell. `.cols` is
  # given explicitly: `across()` without it is deprecated, and only the
  # regret columns are kept afterwards anyway.
  group_by(ub, n) %>%
  summarise(
    across(contains("regret"), list(mean = mean, sd = sd)),
    .groups = "drop"
  ) %>%
  select(contains("regret"), ub, n) %>%
  pivot_longer(-c(ub, n)) %>%
  # Names are "_"-separated: first field = policy, second-to-last field =
  # comparison, last field = summary statistic ("mean"/"sd" from `across()`);
  # a "max" anywhere in the name flags worst-case regret.
  mutate(
    name_parts = str_split(name, "_"),
    policy = map_chr(name_parts, 1),
    comparison = map_chr(name_parts, function(x) head(tail(x, n = 2), n = 1)),
    worst_case = str_detect(name, "max"),
    val = map_chr(name_parts, function(x) tail(x, n = 1))
  ) %>%
  select(-name, -name_parts) %>%
  pivot_wider(names_from = val, values_from = value) %>%
  rename(value = mean) %>%
  # Keep est.oracle vs. oracle, est vs. never-treat ("0") when ub < 1, and
  # est vs. always-treat ("1") when ub > 1; average-case regret only.
  filter(policy == "est.oracle" & comparison == "oracle" |
           policy == "est" & comparison == "0" & ub < 1 |
           policy == "est" & comparison == "1" & ub > 1,
         !worst_case) %>%
  # Restricts the plot to the oracle comparison, so the "0"/"1" rows kept
  # above (and their labels below) are currently unused.
  filter(comparison == "oracle") %>%
  mutate(comparison = case_when(
    comparison == "oracle" ~ "Oracle Policy",
    comparison == "0" ~ "Never Treat Policy",
    comparison == "1" ~ "Always Treat Policy"
  )) %>%
  ggplot(aes(x = n, y = value, color = ub, group = ub)) +
  geom_line() +
  geom_hline(yintercept = 0, linetype = 2) +
  scale_color_gradient2(expression(u[l]), midpoint = 1) +
  guides(color = guide_colorbar(
    title.position = "top", barwidth = 10, barheight = .5, title.align = "c"
  )) +
  # Zoom rather than set scale limits: `ylim()` would silently drop any
  # point outside [0, 0.025] and break the lines at that point.
  coord_cartesian(ylim = c(0, 0.025)) +
  labs(
    x = "Sample size",
    y = "True regret relative to the oracle policy"
  ) +
  theme_bw() +
  theme(legend.position = "bottom")

```

```{r sim_mis_class, fig.width = 4.5, fig.height = 4.5}


# Display labels for the misclassification-rate columns of `res`.
nice_names <- c(
  "constant_policy_error" = "Never Treat Policy",
  "pos_class_error" = "More likely in (1,1) stratum",
  "pos_tau_error" = "Positive CATE"
)
# Mean misclassification rate of each classifier versus sample size, at a
# single upper bound. NOTE(review): `ubs[4]` is the 4th distinct `ub` in
# order of appearance in `res` (`ubs` is set in the regret chunk) — confirm
# this is the intended value.
res %>%
  filter(ub == ubs[4]) %>%
  group_by(n) %>%
  summarise(across(contains("error"), mean)) %>%
  pivot_longer(-n, names_to = "classifier", values_to = "error") %>%
  # Fall back to the raw column name for any error column without a label,
  # so unmapped classifiers get their own facet instead of all collapsing
  # into a single "NA" panel.
  mutate(classifier = coalesce(unname(nice_names[classifier]), classifier)) %>%
  ggplot(aes(x = n, y = error)) +
  geom_line() +
  geom_hline(yintercept = 0, linetype = 2) +
  facet_wrap(~ classifier, ncol = 1, scales = "free_y") +
  xlab("Sample size") +
  scale_y_continuous("Misclassification rate", labels = scales::percent) +
  theme_bw()

```